{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "LqNpENf-ec0X",
        "slideshow": {
          "slide_type": "slide"
        }
      },
      "outputs": [],
      "source": [
        "!pip install -U tf-nightly"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "Pa2qpEmoVOGe",
        "slideshow": {
          "slide_type": "-"
        }
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import time\n",
        "\n",
        "import tensorflow as tf\n",
        "from tensorflow.contrib import autograph\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import six\n",
        "\n",
        "from google.colab import widgets"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "HNqUFL4deCsL",
        "slideshow": {
          "slide_type": "slide"
        }
      },
      "source": [
        "# Case study: training a custom RNN, using Keras and Estimators\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "YkC1k4HEQ7rw",
        "slideshow": {
          "slide_type": "-"
        }
      },
      "source": [
        "In this section, we show how you can use AutoGraph to build RNNColorbot, an RNN that takes as input names of colors and predicts their corresponding RGB tuples. The model will be trained by a [custom Estimator](https://www.tensorflow.org/get_started/custom_estimators)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "7nkPDl5CTCNb",
        "slideshow": {
          "slide_type": "-"
        }
      },
      "source": [
        "To get started, set up the dataset. The following cells defines methods that download and format the data needed for RNNColorbot; the details aren't important (read them in the privacy of your own home if you so wish), but make sure to run the cells before proceeding."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "A0uREmVXCQEw",
        "slideshow": {
          "slide_type": "-"
        }
      },
      "outputs": [],
      "source": [
        "def parse(line):\n",
        "  \"\"\"Parses a line from the colors dataset.\"\"\"\n",
        "  items = tf.string_split([line], \",\").values\n",
        "  rgb = tf.string_to_number(items[1:], out_type=tf.float32) / 255.0\n",
        "  color_name = items[0]\n",
        "  chars = tf.one_hot(tf.decode_raw(color_name, tf.uint8), depth=256)\n",
        "  length = tf.cast(tf.shape(chars)[0], dtype=tf.int64)\n",
        "  return rgb, chars, length\n",
        "\n",
        "\n",
        "def set_static_batch_shape(batch_size):\n",
        "  def apply(rgb, chars, length):\n",
        "    rgb.set_shape((batch_size, None))\n",
        "    chars.set_shape((batch_size, None, 256))\n",
        "    length.set_shape((batch_size,))\n",
        "    return rgb, chars, length\n",
        "  return apply\n",
        "\n",
        "\n",
        "def load_dataset(data_dir, url, batch_size, training=True):\n",
        "  \"\"\"Loads the colors data at path into a tf.PaddedDataset.\"\"\"\n",
        "  path = tf.keras.utils.get_file(os.path.basename(url), url, cache_dir=data_dir)\n",
        "  dataset = tf.data.TextLineDataset(path)\n",
        "  dataset = dataset.skip(1)\n",
        "  dataset = dataset.map(parse)\n",
        "  dataset = dataset.cache()\n",
        "  dataset = dataset.repeat()\n",
        "  if training:\n",
        "    dataset = dataset.shuffle(buffer_size=3000)\n",
        "  dataset = dataset.padded_batch(\n",
        "      batch_size, padded_shapes=((None,), (None, 256), ()))\n",
        "  # To simplify the model code, we statically set as many of the shapes that we\n",
        "  # know.\n",
        "  dataset = dataset.map(set_static_batch_shape(batch_size))\n",
        "  return dataset"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "waZ89t3DTUla",
        "slideshow": {
          "slide_type": "-"
        }
      },
      "source": [
        "To show the use of control flow, we write the RNN loop by hand, rather than using a pre-built RNN model.\n",
        "\n",
        "Note how we write the model code in Eager style, with regular `if` and `while` statements. Then, we annotate the functions with `@autograph.convert` to have them automatically compiled to run in graph mode.\n",
        "We use Keras to define the model, and we will train it using Estimators."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "9v8AJouiC44V",
        "slideshow": {
          "slide_type": "slide"
        }
      },
      "outputs": [],
      "source": [
        "@autograph.convert()\n",
        "class RnnColorbot(tf.keras.Model):\n",
        "  \"\"\"RNN Colorbot model.\"\"\"\n",
        "\n",
        "  def __init__(self):\n",
        "    super(RnnColorbot, self).__init__()\n",
        "    self.lower_cell = tf.contrib.rnn.LSTMBlockCell(256)\n",
        "    self.upper_cell = tf.contrib.rnn.LSTMBlockCell(128)\n",
        "    self.relu_layer = tf.layers.Dense(3, activation=tf.nn.relu)\n",
        "\n",
        "\n",
        "  def _rnn_layer(self, chars, cell, batch_size, training):\n",
        "    \"\"\"A single RNN layer.\n",
        "\n",
        "    Args:\n",
        "      chars: A Tensor of shape (max_sequence_length, batch_size, input_size)\n",
        "      cell: An object of type tf.contrib.rnn.LSTMBlockCell\n",
        "      batch_size: Int, the batch size to use\n",
        "      training: Boolean, whether the layer is used for training\n",
        "\n",
        "    Returns:\n",
        "      A Tensor of shape (max_sequence_length, batch_size, output_size).\n",
        "    \"\"\"\n",
        "    hidden_outputs = []\n",
        "    autograph.utils.set_element_type(hidden_outputs, tf.float32)\n",
        "    state, output = cell.zero_state(batch_size, tf.float32)\n",
        "    for ch in chars:\n",
        "      cell_output, (state, output) = cell.call(ch, (state, output))\n",
        "      hidden_outputs.append(cell_output)\n",
        "    hidden_outputs = hidden_outputs.stack()\n",
        "    if training:\n",
        "      hidden_outputs = tf.nn.dropout(hidden_outputs, 0.5)\n",
        "    return hidden_outputs\n",
        "\n",
        "  def build(self, _):\n",
        "    \"\"\"Creates the model variables. See keras.Model.build().\"\"\"\n",
        "    self.lower_cell.build(tf.TensorShape((None, 256)))\n",
        "    self.upper_cell.build(tf.TensorShape((None, 256)))\n",
        "    self.relu_layer.build(tf.TensorShape((None, 128)))    \n",
        "    self.built = True\n",
        "\n",
        "\n",
        "  def call(self, inputs, training=False):\n",
        "    \"\"\"The RNN model code. Uses Eager and \n",
        "\n",
        "    The model consists of two RNN layers (made by lower_cell and upper_cell),\n",
        "    followed by a fully connected layer with ReLU activation.\n",
        "\n",
        "    Args:\n",
        "      inputs: A tuple (chars, length)\n",
        "      training: Boolean, whether the layer is used for training\n",
        "\n",
        "    Returns:\n",
        "      A Tensor of shape (batch_size, 3) - the model predictions.\n",
        "    \"\"\"\n",
        "    chars, length = inputs\n",
        "    batch_size = chars.shape[0]\n",
        "    seq = tf.transpose(chars, (1, 0, 2))\n",
        "\n",
        "    seq = self._rnn_layer(seq, self.lower_cell, batch_size, training)\n",
        "    seq = self._rnn_layer(seq, self.upper_cell, batch_size, training)\n",
        "\n",
        "    # Grab just the end-of-sequence from each output.\n",
        "    indices = tf.stack([length - 1, range(batch_size)], axis=1)\n",
        "    sequence_ends = tf.gather_nd(seq, indices)\n",
        "    return self.relu_layer(sequence_ends)\n",
        "\n",
        "@autograph.convert()\n",
        "def loss_fn(labels, predictions):\n",
        "  return tf.reduce_mean((predictions - labels) ** 2)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "JjK4gXFvFsf4",
        "slideshow": {
          "slide_type": "slide"
        }
      },
      "source": [
        "We will now create the model function for the custom Estimator.\n",
        "\n",
        "In the model function, we simply use the model class we defined above - that's it!"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "-yso_Nx23Gy1",
        "slideshow": {
          "slide_type": "-"
        }
      },
      "outputs": [],
      "source": [
        "def model_fn(features, labels, mode, params):\n",
        "  \"\"\"Estimator model function.\"\"\"\n",
        "  chars = features['chars']\n",
        "  sequence_length = features['sequence_length']\n",
        "  inputs = (chars, sequence_length)\n",
        "\n",
        "  # Create the model. Simply using the AutoGraph-ed class just works!\n",
        "  colorbot = RnnColorbot()\n",
        "  colorbot.build(None)\n",
        "\n",
        "  if mode == tf.estimator.ModeKeys.TRAIN:\n",
        "    predictions = colorbot(inputs, training=True)\n",
        "    loss = loss_fn(labels, predictions)\n",
        "\n",
        "    learning_rate = params['learning_rate']\n",
        "    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)\n",
        "    global_step = tf.train.get_global_step()\n",
        "    train_op = optimizer.minimize(loss, global_step=global_step)\n",
        "    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)\n",
        "\n",
        "  elif mode == tf.estimator.ModeKeys.EVAL:\n",
        "    predictions = colorbot(inputs)\n",
        "    loss = loss_fn(labels, predictions)\n",
        "\n",
        "    return tf.estimator.EstimatorSpec(mode, loss=loss)\n",
        "\n",
        "  elif mode == tf.estimator.ModeKeys.PREDICT:\n",
        "    predictions = colorbot(inputs)\n",
        "\n",
        "    predictions = tf.minimum(predictions, 1.0)\n",
        "    return tf.estimator.EstimatorSpec(mode, predictions=predictions)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "HOQfoBnHC9CP",
        "slideshow": {
          "slide_type": "-"
        }
      },
      "source": [
        "We'll create an input function that will feed our training and eval data."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "FJZlx7yG2MP0",
        "slideshow": {
          "slide_type": "slide"
        }
      },
      "outputs": [],
      "source": [
        "def input_fn(data_dir, data_url, params, training=True):\n",
        "  \"\"\"An input function for training\"\"\"\n",
        "  batch_size = params['batch_size']\n",
        "  \n",
        "  # load_dataset defined above\n",
        "  dataset = load_dataset(data_dir, data_url, batch_size, training=training)\n",
        "\n",
        "  # Package the pipeline end in a format suitable for the estimator.\n",
        "  labels, chars, sequence_length = dataset.make_one_shot_iterator().get_next()\n",
        "  features = {\n",
        "      'chars': chars,\n",
        "      'sequence_length': sequence_length\n",
        "  }\n",
        "\n",
        "  return features, labels"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "qsvv-lzbDqXd",
        "slideshow": {
          "slide_type": "-"
        }
      },
      "source": [
        "We now have everything in place to build our custom estimator and use it for training and eval!"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          },
          "height": 35
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 10604,
          "status": "ok",
          "timestamp": 1524095272039,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 240
        },
        "id": "2pg1AfbxBJQq",
        "outputId": "9c924b4f-06e1-4538-976c-a3e1ddac5660",
        "slideshow": {
          "slide_type": "-"
        }
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Eval loss at step 100: 0.0674834\n"
          ]
        }
      ],
      "source": [
        "params = {\n",
        "    'batch_size': 64,\n",
        "    'learning_rate': 0.01,\n",
        "}\n",
        "\n",
        "train_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/train.csv\"\n",
        "test_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/test.csv\"\n",
        "data_dir = \"tmp/rnn/data\"\n",
        "\n",
        "regressor = tf.estimator.Estimator(\n",
        "    model_fn=model_fn,\n",
        "    params=params)\n",
        "\n",
        "regressor.train(\n",
        "    input_fn=lambda: input_fn(data_dir, train_url, params),\n",
        "    steps=100)\n",
        "eval_results = regressor.evaluate(\n",
        "    input_fn=lambda: input_fn(data_dir, test_url, params, training=False),\n",
        "    steps=2\n",
        ")\n",
        "\n",
        "print('Eval loss at step %d: %s' % (eval_results['global_step'], eval_results['loss']))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "zG1YAjB_cUnQ",
        "slideshow": {
          "slide_type": "slide"
        }
      },
      "source": [
        "And here's the same estimator used for inference."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          },
          "height": 343
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 7990,
          "status": "ok",
          "timestamp": 1524095280105,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 240
        },
        "id": "dxHex2tUN_10",
        "outputId": "2b889e5a-b9ed-4645-bf03-d98f26c72101",
        "slideshow": {
          "slide_type": "slide"
        }
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "\u003clink rel=stylesheet type=text/css href='/nbextensions/google.colab/tabbar.css'\u003e\u003c/link\u003e"
            ],
            "text/plain": [
              "\u003cIPython.core.display.HTML at 0x7f3f36aa6cd0\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "outputarea_id1"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [
              "\u003cscript src='/nbextensions/google.colab/tabbar_main.min.js'\u003e\u003c/script\u003e"
            ],
            "text/plain": [
              "\u003cIPython.core.display.HTML at 0x7f3eca67f7d0\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "outputarea_id1"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [
              "\u003cdiv id=\"id1\"\u003e\u003c/div\u003e"
            ],
            "text/plain": [
              "\u003cIPython.core.display.HTML at 0x7f3eca67f8d0\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "outputarea_id1"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"e8ddfa22-4362-11e8-91ec-c8d3ffb5fbe0\"] = colab_lib.createTabBar({\"contentBorder\": [\"0px\"], \"elementId\": \"id1\", \"borderColor\": [\"#a7a7a7\"], \"contentHeight\": [\"initial\"], \"tabNames\": [\"RNN Colorbot\"], \"location\": \"top\", \"initialSelection\": 0});\n",
              "//# sourceURL=js_71b9087b6d"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7f3eca67f950\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "outputarea_id1"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"e8ddfa23-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n",
              "//# sourceURL=js_e390445f33"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7f3eca67f990\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "outputarea_id1"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"e8ddfa24-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n",
              "//# sourceURL=js_241dd76d85"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7f3eca67fc50\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id1_content_0",
              "outputarea_id1"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"e8ddfa25-4362-11e8-91ec-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n",
              "//# sourceURL=js_60c64e3d50"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7f3eca67fd90\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id1_content_0",
              "outputarea_id1"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"e8ddfa26-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"e8ddfa25-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n",
              "//# sourceURL=js_14ea437cbd"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7f3eca67fe10\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id1_content_0",
              "outputarea_id1"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"e8ddfa27-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n",
              "//# sourceURL=js_09294c2226"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7f3eca67fcd0\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id1_content_0",
              "outputarea_id1"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"ec965514-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"e8ddfa24-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n",
              "//# sourceURL=js_e5e8266997"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7f3eca67fe10\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id1_content_0",
              "outputarea_id1"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"ec965515-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n",
              "//# sourceURL=js_07a097f0ee"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7f3eca67fc90\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id1_content_0",
              "outputarea_id1"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"ec965516-4362-11e8-91ec-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n",
              "//# sourceURL=js_790d669ca8"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7f3eca67f8d0\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id1_content_0",
              "outputarea_id1"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"ec965517-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"ec965516-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n",
              "//# sourceURL=js_d30df771f0"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7f3eca67fd90\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id1_content_0",
              "outputarea_id1"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"ec965518-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n",
              "//# sourceURL=js_8a43a2da4b"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7f3eca67fc50\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id1_content_0",
              "outputarea_id1"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQwAAAENCAYAAAD60Fs2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAACMBJREFUeJzt3F+I1XX+x/G32zjiFERUpgaFd2JBzOg5joX4h0SiMgmM\n/uhVGIlgFBlERGB3hUEkhkRdtDfRP1ACL6KpLBqcguxCjEAkmGamQcSohFHzsxe7O6zssvsydtff\n+ns8rs758j3f8z7fiyef7/k3o7XWCiDwh4s9APC/QzCAmGAAMcEAYoIBxAQDiAkGF8XTTz9d3W63\n7rvvvhoZGakVK1Zc7JEICMYlbvXq1TU8PHyxxzjPV199VcPDw/XZZ5/V22+/XVVVM2bMuMhTkRAM\n/qt+++23+uGHH+r666+vWbNmXexxuECCcQl76qmnanx8vLZs2VIDAwP1+uuv1zfffFP3339/dTqd\nWr9+fY2MjEzvv2nTpnr55ZfrgQceqIGBgXr44Yfr5MmTVVV1+vTp2r59ey1durQ6nU5t2LChTpw4\nUVVVk5OTtWXLllq6dGmtXbu23nnnnelj7tq1q7Zt21bbt2+vJUuW1HvvvVfPPvtsHTp0qAYGBmrX\nrl1/N/fRo0dr06ZN1el06u67766hoaGqqhodHa1OpzO93zPPPFO33nrr9P3t27fXm2+++e89iZyv\ncUlbtWpVGx4ebq21NjEx0brdbjtw4EBrrbUvvviidbvdduLEidZaaxs3bmxr1qxp33//fZuammob\nN25sO3fubK219tZbb7VHH320TU1NtXPnzrXDhw+3X375pbXW2kMPPdR27NjRTp8+3Y4cOdIGBwen\nn/OVV15pN910U/voo49aa61NTU21999/vz344IPTMx48eLCtWLGitdbamTNn2po1a9qePXvamTNn\n2vDwcOvv72/Hjh2bfj2HDx9urbW2du3advvtt7ejR4+21lpbuXJlO3LkyH/qVNJas8L4f6D95edC\n+/btq5UrV9by5curqmrZsmV1880316effjq977333ls33HBD9fb21h133FFHjhypqqqenp46efJk\nHTt2rGbMmFGLFi2qyy+/vCYmJurrr7+uJ598smbOnFkLFy6sDRs21N69e6eP2d/fX6tXr66qqt7e\n3n8666FDh+rUqVP1yCOPVE9PTw0ODtaqVavqgw8+qKqqJUuW1MjISB0/fryqqtauXVtffvlljY6O\n1q+//loLFy78N501/pGeiz0A/z1jY2O1f//++vjjj6vqzyE5e/ZsLVu2bHqfa665Zvr27Nmz69Sp\nU1VVdc8999TExEQ98cQT9fPPP9e6devq8ccfr8nJybryyitr9uzZ04+bP39+HT58ePr+3Llz4xkn\nJydr3rx5522bP39+TU5OVlVVp9OpoaGhuu6666rb7Va32629e/dWb29vLV68+ALOBr+HYFzi/vbT\nh3nz5tX69etrx44dF3ycnp6e2rp1a23durXGxsZq8+bNtWDBgrrtttvqp59+qlOnTlVfX19VVY2P\nj9ecOXP+4Qz/ypw5c2p8fPy8bWNjY7VgwYKqqup2u/Xiiy/WvHnzqtPp1MDAQD333HPV29tb3W73\ngl8XF8YlySXu2muvrdHR0aqqWrduXQ0NDdXnn39e586dq6mpqRoZGakff/zxXx7n4MGD9d1339W5\nc+eqr6+venp66rLLLqu5c+dWf39/vfTSS3X69On69ttv6913361169b9rnlvueWW6uvrq9dee63O\nnj1bBw8erE8++aTuvPPOqqq68cYba9asWbVv377qdDp1xRVX1NVXX10ffvjheW+I8p8hGJe4zZs3\n1+7du6vb7db+/ftr9+7dtWfPnlq2bFmtWrWq3njjjen3OP7ZSuD48eO1bdu2Wrx4cd111121dOnS\n6Sjs3LmzRkdHa/ny5bVt27Z6
7LHHzrvMuRAzZ86sV199tQ4cOFCDg4P1/PPP1wsvvDC9wqj68yrj\nqquumr7U+WsoFi1a9Luek9yM1vyBDpCxwgBiggHEBAOICQYQ+z/7PYzjf/QRGVxM12z68u+2WWEA\nMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHE\nBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhAT\nDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEww\ngJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEA\nYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOI\nCQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAm\nGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhg\nADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIB\nxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQ\nEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBM\nMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHB\nAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQD\niAkGEBMMIDajtdYu9hDA/wYrDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEA\nYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4j9CY2LTAbbRbWuAAAAAElFTkSuQmCC\n",
            "text/plain": [
              "\u003cmatplotlib.figure.Figure at 0x7f3ecc00bf10\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id1_content_0",
              "outputarea_id1",
              "user_output"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"ec965519-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"ec965515-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n",
              "//# sourceURL=js_893ad561f4"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7f3f31b55c90\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id1_content_0",
              "outputarea_id1"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"ec96551a-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n",
              "//# sourceURL=js_2d99e0ac17"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7f3eca67fe50\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id1_content_0",
              "outputarea_id1"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"ec96551b-4362-11e8-91ec-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n",
              "//# sourceURL=js_5c19462e32"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7f3f31b55dd0\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id1_content_0",
              "outputarea_id1"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"ec96551c-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"ec96551b-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n",
              "//# sourceURL=js_b9c8b7567b"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7f3f31b55a50\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id1_content_0",
              "outputarea_id1"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"ec96551d-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n",
              "//# sourceURL=js_fd05186348"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7f3f31b55810\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id1_content_0",
              "outputarea_id1"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [
              "\u003cdiv class=id_888646481 style=\"margin-right:10px; display:flex;align-items:center;\"\u003e\u003cspan style=\"margin-right: 3px;\"\u003e\u003c/span\u003e\u003c/div\u003e"
            ],
            "text/plain": [
              "\u003cIPython.core.display.HTML at 0x7f3f32414810\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id1_content_0",
              "outputarea_id1",
              "user_output"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"ec96551e-4362-11e8-91ec-c8d3ffb5fbe0\"] = jQuery(\".id_888646481 span\");\n",
              "//# sourceURL=js_efef96e882"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7f3f31b55710\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id1_content_0",
              "outputarea_id1",
              "user_output"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"ec96551f-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"ec96551e-4362-11e8-91ec-c8d3ffb5fbe0\"].text(\"Give me a color name (or press 'enter' to exit): \");\n",
              "//# sourceURL=js_6eca889864"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7f3eca67f990\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id1_content_0",
              "outputarea_id1",
              "user_output"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"ed8ea972-4362-11e8-91ec-c8d3ffb5fbe0\"] = jQuery(\".id_888646481 input\");\n",
              "//# sourceURL=js_f02070cc60"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7f3f31b553d0\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id1_content_0",
              "outputarea_id1",
              "user_output"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"ed8ea973-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"ed8ea972-4362-11e8-91ec-c8d3ffb5fbe0\"].remove();\n",
              "//# sourceURL=js_ed9faba660"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7f3f31a95450\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id1_content_0",
              "outputarea_id1",
              "user_output"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"ed8ea974-4362-11e8-91ec-c8d3ffb5fbe0\"] = jQuery(\".id_888646481 span\");\n",
              "//# sourceURL=js_f3458d7074"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7f3f31a95250\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id1_content_0",
              "outputarea_id1",
              "user_output"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"ed8ea975-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"ed8ea974-4362-11e8-91ec-c8d3ffb5fbe0\"].text(\"Give me a color name (or press 'enter' to exit): \");\n",
              "//# sourceURL=js_3ffd97bd6f"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7f3f31a953d0\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id1_content_0",
              "outputarea_id1",
              "user_output"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"ed8ea976-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"ec96551a-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n",
              "//# sourceURL=js_7f73e8bcca"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7f3f31b55710\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id1_content_0",
              "outputarea_id1"
            ]
          },
          "output_type": "display_data"
        }
      ],
      "source": [
        "def predict_input_fn(color_name):\n",
        "  \"\"\"Builds the prediction input for a single color name.\n",
        "\n",
        "  Args:\n",
        "    color_name: the color name, forwarded unchanged to `parse`.\n",
        "\n",
        "  Returns:\n",
        "    A `(features, labels)` pair. `features` maps 'chars' and\n",
        "    'sequence_length' to the parsed values with a new leading axis of\n",
        "    size one; `labels` is always None.\n",
        "  \"\"\"\n",
        "  # The first value returned by `parse` is not needed for prediction.\n",
        "  _unused, encoded_chars, length = parse(color_name)\n",
        "\n",
        "  # Wrap each value in a batch that holds a single element.\n",
        "  batched_chars = tf.expand_dims(encoded_chars, 0)\n",
        "  batched_length = tf.expand_dims(length, 0)\n",
        "  return {'chars': batched_chars, 'sequence_length': batched_length}, None\n",
        "\n",
        "\n",
        "def draw_prediction(color_name, pred):\n",
        "  \"\"\"Displays `pred` as an image titled with `color_name`.\n",
        "\n",
        "  Args:\n",
        "    color_name: text used as the plot title.\n",
        "    pred: numpy array that is scaled by 255 and cast to uint8 before being\n",
        "      shown; presumably color values in [0, 1] -- verify against caller.\n",
        "  \"\"\"\n",
        "  pixels = (pred * 255).astype(np.uint8)\n",
        "  plt.axis('off')\n",
        "  plt.imshow(pixels)\n",
        "  plt.title(color_name)\n",
        "  plt.show()\n",
        "\n",
        "\n",
        "def predict_with_estimator(color_name, regressor):\n",
        "  \"\"\"Predicts the color for `color_name` with `regressor` and plots it.\n",
        "\n",
        "  Args:\n",
        "    color_name: name of the color to predict; also used as the plot title.\n",
        "    regressor: object whose `predict(input_fn=...)` returns a closable\n",
        "      iterator of predictions (presumably a trained Estimator -- confirm).\n",
        "  \"\"\"\n",
        "  predictions = regressor.predict(\n",
        "      input_fn=lambda: predict_input_fn(color_name))\n",
        "  try:\n",
        "    pred = next(predictions)\n",
        "  finally:\n",
        "    # Always release the prediction iterator, even when fetching the first\n",
        "    # result fails.\n",
        "    predictions.close()\n",
        "  # Clamp on both sides: draw_prediction casts to uint8, where values\n",
        "  # outside [0, 1] would wrap around instead of saturating.\n",
        "  pred = np.clip(pred, 0.0, 1.0)\n",
        "  # Add two leading axes so the prediction is drawn as a 1x1 image.\n",
        "  pred = np.expand_dims(np.expand_dims(pred, 0), 0)\n",
        "\n",
        "  draw_prediction(color_name, pred)\n",
        "\n",
        "# A single-tab widget; both the prompt and the prediction are sent to tab 0.\n",
        "tb = widgets.TabBar([\"RNN Colorbot\"])\n",
        "# Keep prompting until an empty name is entered or input is interrupted.\n",
        "while True:\n",
        "  with tb.output_to(0):\n",
        "    try:\n",
        "      color_name = six.moves.input(\"Give me a color name (or press 'enter' to exit): \")\n",
        "    except (EOFError, KeyboardInterrupt):\n",
        "      # End of input or an interrupt stops the demo.\n",
        "      break\n",
        "  if not color_name:\n",
        "    break\n",
        "  with tb.output_to(0):\n",
        "    # Clear the tab before the new prediction is drawn into it.\n",
        "    tb.clear_tab()\n",
        "    predict_with_estimator(color_name, regressor)\n",
        "  "
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "default_view": {},
      "last_runtime": {
        "build_target": "",
        "kind": "local"
      },
      "name": "RNN Colorbot using Keras and Estimators",
      "provenance": [
        {
          "file_id": "1CtzefX39ffFibX_BqE6cRbT0UW_DdVKl",
          "timestamp": 1523579810961
        },
        {
          "file_id": "1DcfimonWU11tmyivKBGVrbpAl3BIOaRG",
          "timestamp": 1523016192637
        },
        {
          "file_id": "1wCZUh73zTNs1jzzYjqoxMIdaBWCdKJ2K",
          "timestamp": 1522238054357
        },
        {
          "file_id": "1_HpC-RrmIv4lNaqeoslUeWaX8zH5IXaJ",
          "timestamp": 1521743157199
        },
        {
          "file_id": "1mjO2fQ2F9hxpAzw2mnrrUkcgfb7xSGW-",
          "timestamp": 1520522344607
        }
      ],
      "version": "0.3.2",
      "views": {}
    },
    "kernelspec": {
      "display_name": "Python 2",
      "name": "python2"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
